from __future__ import annotations
import os
import imageio
import pickle
import numpy as np
import json
from graphviz import Digraph
from PIL import Image, ImageOps


class TreeNode:
    """One node of a part tree: a label, an optional image, and its links.

    A node created without a ``parent`` is treated as a root and is its own
    parent.
    """

    def __init__(self, label, image=None, image_path=None, parent=None) -> None:
        """Create a node.

        Args:
            label: Identifier of the node.
            image: Optional image for this node (presumably a NumPy array,
                since ``PartTree.query_preprocess`` passes it to
                ``Image.fromarray`` -- confirm for other callers).
            image_path: Optional path for ``image``; stored as an absolute path.
            parent: Parent node; defaults to the node itself (root).
        """
        self.label = label
        self.image = image
        self.image_path = self._abspath_or_none(image_path)
        self.children = []
        # A node without an explicit parent is a root and points at itself.
        self.parent = parent if parent is not None else self
        # Populated by PartTree.query_preprocess.
        self.query_image = None
        self.query_image_path = None
        self.caption = ''
        # Initialised so that reading ``embed`` before set_embed() yields None
        # instead of raising AttributeError.
        self.embed = None

    @staticmethod
    def _abspath_or_none(path):
        """Return ``path`` made absolute, or ``None`` when ``path`` is ``None``."""
        return os.path.abspath(path) if path is not None else None

    def add_child(self, child_node: TreeNode) -> None:
        """Append ``child_node`` to this node's children (its ``parent`` is not touched)."""
        self.children.append(child_node)

    def set_caption(self, caption) -> None:
        """Set the text shown under the node's image when rendered."""
        self.caption = caption

    def set_image(self, image, image_path) -> None:
        """Replace the node's image; ``image_path`` is made absolute as in ``__init__``."""
        self.image = image
        # Normalise like the constructor so later os.path.dirname() calls never
        # see a bare relative filename.
        self.image_path = self._abspath_or_none(image_path)

    def set_embed(self, embed) -> None:
        """Store an embedding for this node."""
        self.embed = embed

class PartTree:
    """Labelled tree of part images rooted at a single ``TreeNode``.

    Every node is also indexed by label in ``label_to_node``; ``add_edge``
    rejects duplicate labels, so labels are unique within the tree.
    """

    def __init__(self, output_dir, label, image=None, image_path=None) -> None:
        """Create a tree that contains only a root node.

        Args:
            output_dir: Directory created by ``render_tree`` before rendering.
            label: Label of the root node.
            image: Optional image for the root node.
            image_path: Optional path of the root image.
        """
        self.output_dir = output_dir
        self.root = TreeNode(label, image, image_path)
        # Flat label -> node index, maintained by add_edge and canonize_tree.
        self.label_to_node = {label: self.root}

    def get_nodes(self) -> list:
        """Return all indexed nodes, in the order they were added."""
        return list(self.label_to_node.values())

    def get_leaves(self) -> list:
        """Return the indexed nodes that have no children."""
        return [node for node in self.get_nodes() if not node.children]

    def set_node_image(self, label, image, image_path) -> None:
        """Attach ``image``/``image_path`` to the node named ``label``.

        Raises:
            AssertionError: If no node has that label.
        """
        assert self.exists(label), f"Node {label} does not exist"
        self.label_to_node[label].set_image(image, image_path)

    def set_node_embed(self, label, embed) -> None:
        """Attach an embedding to the node named ``label``.

        Raises:
            AssertionError: If no node has that label.
        """
        assert self.exists(label), f"Node {label} does not exist"
        self.label_to_node[label].set_embed(embed)

    def add_edge(self, child_label, image=None, image_path=None, parent_label=None) -> None:
        """Create a new node ``child_label`` under ``parent_label`` (root if ``None``).

        Raises:
            KeyError: If ``parent_label`` is given but not in the tree.
            ValueError: If ``child_label`` is already used.
        """
        if parent_label is None:
            parent = self.root
        else:
            if parent_label not in self.label_to_node:
                raise KeyError(f"Label {parent_label} not in tree")
            parent = self.label_to_node[parent_label]

        if self.exists(child_label):
            raise ValueError(f"Label of {child_label} already exists.")
        child_node = TreeNode(child_label, image, image_path, parent=parent)
        parent.add_child(child_node)
        self.label_to_node[child_label] = child_node

    def exists(self, label) -> bool:
        """Return whether a node with ``label`` is indexed in the tree."""
        return label in self.label_to_node

    def canonize_tree(self) -> None:
        """Collapse chains of single-child nodes.

        While a node has exactly one child, that child is removed (also from
        ``label_to_node``) and its children are adopted by the node. The
        surviving children get their ``parent`` re-pointed and are processed
        the same way.
        """
        def canonize(node):
            while len(node.children) == 1:
                only_child = node.children[0]
                self.label_to_node.pop(only_child.label)
                node.children = only_child.children
            for child in node.children:
                child.parent = node
                canonize(child)

        canonize(self.root)

    def add_graphviz_edge(self, node, graph):
        """Add ``node`` and its whole subtree to ``graph``, writing each node's image.

        The query image (set by ``query_preprocess``) is preferred over the raw
        image. It is saved to disk because the Graphviz node references it by
        file path.

        Raises:
            AssertionError: If ``node`` (or a descendant) has no image or image path.
        """
        assert node.image is not None, f"Image not defined for node {node.label}"
        assert node.image_path is not None, f"Image path not defined for node {node.label}"

        if node.query_image is not None:
            image_file = node.query_image_path
            pixels = node.query_image
        else:
            image_file = node.image_path
            pixels = node.image

        # Create the directory of the file actually being written. dirname() is
        # '' for a bare filename, which os.makedirs would reject.
        target_dir = os.path.dirname(image_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        imageio.imsave(image_file, pixels)

        # Plain-text node: image at top-centre ('tc'), caption at the bottom ('b').
        graph.node(str(node.label), label=node.caption, image=image_file,
                   imagepos='tc', labelloc='b', shape='plaintext', fontcolor='green')
        for child in node.children:
            graph.edge(str(node.label), str(child.label))
            self.add_graphviz_edge(child, graph)

    def render_tree(self, output_file) -> None:
        """Render the entire part tree to PDF with Graphviz.

        Every node must have an image (see ``add_graphviz_edge``); node images
        are written to disk as a side effect.

        Args:
            output_file: Path handed to ``Digraph.render``.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        dot = Digraph(comment='Tree with Images', format='pdf', engine='dot')
        dot.attr(size='100,100')  # Graphviz max drawing size (inches); adjust as needed
        self.add_graphviz_edge(self.root, dot)
        # Produces a PDF (format='pdf'); view=False means no viewer is opened.
        dot.render(output_file, view=False)

    def query_preprocess(self, **kwargs) -> None:
        """Build, for every node, the query image fed to the image encoder.

        Each node's ``image`` is converted to an RGB PIL image, optionally
        cropped/padded/resized, and stored back as a NumPy array in
        ``node.query_image``. ``node.query_image_path`` is set to the node's
        image path with its extension replaced by ``_query.png``; this method
        does not write that file.

        Keyword Args:
            crop: If truthy, crop to the bounding box of the non-zero pixels.
            pad: If truthy, pad to a square with black borders.
            resize: Target ``(width, height)``; an int means a square of that
                side. Skipped when absent or ``None``.

        Raises:
            AssertionError: If a node has no image or image path.
        """
        crop = kwargs.get('crop')
        pad = kwargs.get('pad')
        resize = kwargs.get('resize')
        if isinstance(resize, int) and not isinstance(resize, bool):
            resize = (resize, resize)

        def make_query_image(node: TreeNode) -> None:
            """Compute and store the query image/path for a single node."""
            assert node.image is not None, f"Image not defined for node {node.label}"
            assert node.image_path is not None, f"Image path not defined for node {node.label}"
            query_img = Image.fromarray(node.image).convert("RGB")

            if crop:
                # getbbox() is None for an all-black image; crop(None) then
                # returns an uncropped copy.
                query_img = query_img.crop(query_img.getbbox())

            if pad:
                # PIL reports size as (width, height).
                width, height = query_img.size
                side = max(width, height)
                query_img = ImageOps.pad(query_img, (side, side), color=(0, 0, 0))

            if resize is not None:
                query_img = query_img.resize(resize)

            node.query_image = np.array(query_img)
            # splitext handles extensions of any length (e.g. '.jpeg').
            stem, _ = os.path.splitext(node.image_path)
            node.query_image_path = stem + '_query.png'

        # Preorder traversal with an explicit stack (no recursion-depth limit).
        pending = [self.root]
        while pending:
            current = pending.pop()
            make_query_image(current)
            pending.extend(reversed(current.children))

    def format_json(self):
        """Return the tree as nested dicts that mirror its structure.

        Each dict has ``"label"`` and ``"image"`` (the node's ``query_image``:
        a NumPy array after ``query_preprocess``, otherwise ``None``), plus a
        ``"children"`` list for non-leaf nodes.

        NOTE: NumPy arrays are not JSON-serialisable; convert them (e.g. with
        ``.tolist()``) before passing the result to ``json.dumps``.
        """
        def tree_to_dict(node):
            node_dict = {"label": node.label, "image": node.query_image}
            if node.children:
                node_dict["children"] = [tree_to_dict(child) for child in node.children]
            return node_dict

        return tree_to_dict(self.root)



def save_tree(tree, output_path) -> None:
    """Serialise ``tree`` with pickle into ``output_path`` and print a confirmation.

    Args:
        tree: Object to persist (typically a ``PartTree``).
        output_path: Destination file; it is created or overwritten.
    """
    with open(output_path, 'wb') as stream:
        pickler = pickle.Pickler(stream)
        pickler.dump(tree)
    print(f'Tree saved to {output_path}')

def load_tree(path) -> PartTree:
    """Unpickle and return the object stored at ``path`` (e.g. written by ``save_tree``).

    Warning: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    with open(path, 'rb') as stream:
        return pickle.Unpickler(stream).load()



